import ConfigSpace

from hpo.hpo_base import *
from smac.facade.smac_hpo_facade import SMAC4HPO
from smac.facade.smac_mf_facade import SMAC4MF
from smac.scenario.scenario import Scenario
from smac.initial_design.random_configuration_design import RandomConfigurations
import numpy as np
import os
import pickle
from datetime import datetime
from hpo.utils import get_reward_from_trajectory


class SMAC(HyperparameterOptimizer):
    """
    Implementation of the sequential model-based algorithm configuration optimizer
    Master repository:
        https://github.com/automl/SMAC3

    Wraps SMAC3's ``SMAC4HPO`` facade (or ``SMAC4MF`` when ``multi_fidelity`` is set)
    around ``env.train_batch``. Every evaluated configuration and its cost (the negated
    reward) are appended to ``self.X`` / ``self.y``.
    """

    # Default smallest multi-fidelity budget, as a fraction of ``max_timesteps``.
    _DEFAULT_INITIAL_BUDGET_FRACTION = 0.1
    _DEFAULT_ETA = 3

    def __init__(self, env,
                 log_dir,
                 max_timesteps: int = None,
                 max_iters: int = 100, batch_size: int = 1,
                 n_repetitions: int = 1,
                 multi_fidelity: bool = False,
                 multi_fidelity_config: dict = None,
                 rng_state: int = 42,
                 use_reward='final',
                 n_init: int = None,
                 anneal_lr: bool = False,
                 dummy=False):
        """
        Args:
            env: environment exposing ``config_space``, ``default_env_args``, ``seed``
                and ``train_batch``.
            log_dir: directory for SMAC output (``smac_logs/<timestamp>``) and ``stats.pkl``.
            max_timesteps: training timesteps per evaluation; defaults to
                ``env.default_env_args['num_timesteps']``.
            max_iters: SMAC ``runcount-limit`` (number of target-function evaluations).
            batch_size, n_repetitions: forwarded to ``HyperparameterOptimizer``.
            multi_fidelity: use ``SMAC4MF`` instead of ``SMAC4HPO``.
            multi_fidelity_config: optional overrides with keys ``init_budget`` (alias
                ``initial_budget``), ``max_budget`` and ``eta``. An int budget >= 1 is an
                absolute number of timesteps; a float budget is a fraction of ``max_timesteps``.
            rng_state: seed for the ``np.random.RandomState`` handed to SMAC.
            use_reward: ``'final'`` or an int forwarded to ``get_reward_from_trajectory``.
            n_init: number of random initial configurations; defaults to ``2 * dim + 1``.
            anneal_lr: forwarded to ``env.train_batch``.
            dummy: stored as ``self.dummy``; not read by this class.
        """
        super().__init__(env, max_iters, batch_size, n_repetitions)
        self.max_timesteps = max_timesteps if max_timesteps is not None else env.default_env_args['num_timesteps']
        self.multi_fidelity = multi_fidelity
        self.dim = len(env.config_space.get_hyperparameters())
        self.n_init = n_init if n_init is not None and n_init > 0 else 2 * self.dim + 1

        self.log_dir = log_dir
        self.dummy = dummy
        self.anneal_lr = anneal_lr

        # Build the intensifier kwargs from defaults, overridden by valid user values.
        # The user dict is never mutated.
        user_mf_config = multi_fidelity_config or {}
        user_initial_budget = user_mf_config.get('init_budget')
        if user_initial_budget is None:
            # Also accept SMAC's own key name.
            user_initial_budget = user_mf_config.get('initial_budget')
        default_initial_budget = self._resolve_budget(self._DEFAULT_INITIAL_BUDGET_FRACTION, 1)
        self.multi_fidelity_config = {
            'initial_budget': self._resolve_budget(user_initial_budget, default_initial_budget),
            'max_budget': self._resolve_budget(user_mf_config.get('max_budget'), int(self.max_timesteps)),
            # A falsy eta (None/0) keeps the default, as before.
            'eta': user_mf_config.get('eta') or self._DEFAULT_ETA,
        }

        scenario = Scenario(
            {
                'run_obj': 'quality',
                'cs': self.env.config_space,
                'deterministic': 'true',
                'runcount-limit': self.max_iters,
                # 'n_seeds' (repetitions per config for a non-deterministic objective)
                # is intentionally not set.
            })
        time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
        scenario.output_dir = os.path.join(self.log_dir, 'smac_logs', time_string)

        smac_kwargs = dict(
            scenario=scenario,
            rng=np.random.RandomState(rng_state),
            tae_runner=self._obj_func_handle,
            # default is sobol which is problematic for conditional hyperparams
            initial_design=RandomConfigurations,
            initial_design_kwargs={'init_budget': self.n_init},
        )
        if self.multi_fidelity:
            self.smac = SMAC4MF(intensifier_kwargs=self.multi_fidelity_config, **smac_kwargs)
        else:
            self.smac = SMAC4HPO(**smac_kwargs)
        self.cur_iters = 0

        # Value forwarded to ``get_reward_from_trajectory`` as ``use_last_fraction``:
        # 'final' -> 1, an int -> that int, anything else -> 1.
        # NOTE(review): the attribute says "last n" while the callee's kwarg says
        # "fraction" -- confirm which semantics get_reward_from_trajectory expects.
        if use_reward == 'final':
            self.use_last_n_reward = 1
        else:
            self.use_last_n_reward = use_reward if isinstance(use_reward, int) else 1

    def _resolve_budget(self, value, default):
        """Translate a user-supplied budget into an absolute number of timesteps.

        - ``None``, bools, ints < 1 and unsupported types return ``default``.
        - Ints >= 1 (Python or NumPy) are used verbatim.
        - Floats are treated as a fraction of ``self.max_timesteps``, rounded, and
          clamped to at least 1 timestep.
        """
        if value is None or isinstance(value, bool):
            return default
        if isinstance(value, (int, np.integer)):
            return int(value) if value >= 1 else default
        if isinstance(value, (float, np.floating)):
            return max(1, int(np.round(float(value) * self.max_timesteps)))
        return default

    def _evaluate_at_full_budget(self, config):
        """Run ``config`` through SMAC's TAE runner at ``max_timesteps``; return element [1] of the result."""
        return self.smac.get_tae_runner().run(
            config=config,
            budget=int(self.max_timesteps),
        )[1]

    def run(self,):
        """Evaluate the default config, run SMAC, re-evaluate the incumbent, and save results.

        The default and the incumbent are evaluated at the full ``max_timesteps`` budget
        through SMAC's TAE runner, so they are also recorded in ``self.X`` / ``self.y``.
        The printed value is element ``[1]`` of the runner's return value (presumably
        the cost -- TODO confirm against the installed SMAC version).

        Returns:
            tuple: ``(self.X, self.y)``, all evaluated configurations and their costs.
        """
        def_value = self._evaluate_at_full_budget(self.env.config_space.get_default_configuration())
        print(f'Value for default configuration = {def_value}')
        incumbent = self.smac.optimize()

        inc_value = self._evaluate_at_full_budget(incumbent)
        print(f'Optimised value = {inc_value}')

        stats_path = os.path.join(self.log_dir, 'stats.pkl')
        print(f'Saving intermediate results to {stats_path}')
        # Context manager so the file handle is flushed and closed.
        with open(stats_path, 'wb') as stats_file:
            pickle.dump([self.X, self.y], stats_file)

        return self.X, self.y

    def _obj_func_handle(self, config, budget):
        """SMAC target function: train ``config`` once and return its negated reward.

        Args:
            config: configuration drawn from ``env.config_space``.
            budget: budget supplied by SMAC. NOTE(review): currently ignored; every
                evaluation trains for ``self.max_timesteps``, so lower budgets in
                multi-fidelity mode are not exploited. Confirm this is intended.

        Returns:
            The negated reward (negated, presumably because SMAC minimises cost).
        """
        trajectory = self.env.train_batch(
            [config],
            seeds=[self.env.seed],
            nums_timesteps=[self.max_timesteps],
            max_parallel=1,
            anneal_lr=self.anneal_lr,
        )[0]
        # ``np.float`` was removed in NumPy 1.24; builtin ``float`` maps to the same float64 dtype.
        rewards = np.array(trajectory['y'], dtype=float)
        cost = -get_reward_from_trajectory(rewards, use_last_fraction=self.use_last_n_reward)
        self.cur_iters += 1
        self.X.append(config)
        self.y.append(cost)
        return cost
